package cn.edu.bjtu.model.core;

import java.io.IOException;
import java.util.concurrent.Future;

import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cpu.nativecpu.NDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import cn.edu.bjtu.api.NetworkModelService;
import cn.edu.bjtu.core.Deep4jModelType;
/**
 * {@link Slave} implementation backed by a DL4J {@link MultiLayerNetwork}.
 * <p>
 * The network is restored by {@link #load()} from the stream provided by
 * {@code getStream()}; {@link #output(INDArray...)} and
 * {@link #evaluate(DataSetIterator)} require {@link #load()} to have been called first.
 */
class MultiLayerNetworkSlave extends Slave{
	private static final long serialVersionUID = -5726275217116914260L;

	/** The restored network; {@code null} until {@link #load()} has completed. */
	private MultiLayerNetwork mln = null;

	/**
	 * @param nms the network model service, passed unchanged to the {@link Slave} constructor
	 */
	public MultiLayerNetworkSlave(NetworkModelService nms) {
		super(nms);
	}

	/**
	 * Runs a forward pass of the wrapped network.
	 * <p>
	 * Only {@code input[0]} is fed to the network; any further arrays are ignored.
	 *
	 * @param input input arrays; at least one is required
	 * @return a one-element array holding the network output for {@code input[0]}
	 * @throws IllegalArgumentException if no input array is given
	 * @throws IllegalStateException if {@link #load()} has not been called
	 */
	@Override
	public INDArray[] output(INDArray... input) {
		if (input == null || input.length == 0) {
			throw new IllegalArgumentException("output requires at least one input array");
		}
		// Allocate as INDArray[] (not a concrete backend type such as NDArray[]) so that
		// storing whatever INDArray implementation the network returns cannot throw
		// ArrayStoreException; the single result lives at index 0.
		return new INDArray[] { requireLoaded().output(input[0]) };
	}

	/**
	 * Evaluates the wrapped network on the data produced by {@code iterator}.
	 *
	 * @param iterator source of evaluation data
	 * @return the evaluation result computed by {@link MultiLayerNetwork#evaluate(DataSetIterator)}
	 * @throws IllegalStateException if {@link #load()} has not been called
	 */
	@Override
	public Evaluation evaluate(DataSetIterator iterator) {
		return requireLoaded().evaluate(iterator);
	}

	/**
	 * Restores the network from the stream returned by {@code getStream()}.
	 *
	 * @throws IOException if the serialized model cannot be read
	 */
	@Override
	public void load() throws IOException {
		// NOTE(review): whether the stream from getStream() is closed here depends on
		// ModelSerializer and on Slave.getStream(); confirm ownership of that stream.
		mln = ModelSerializer.restoreMultiLayerNetwork(getStream());
	}

	/** @return {@link Deep4jModelType#MultiLayerNetwork} */
	@Override
	public Deep4jModelType getActualType() {
		return Deep4jModelType.MultiLayerNetwork;
	}

	/**
	 * Currently always returns {@code null}.
	 */
	@Override
	Object get() {
		// NOTE(review): the contract of Slave.get() is not visible here; confirm whether it
		// is expected to expose the underlying network (mln) instead of null.
		return null;
	}

	/**
	 * Asynchronous output is not implemented; always returns {@code null}.
	 */
	@Override
	public Future<INDArray[]> outputAsync(INDArray... input) throws Exception {
		// NOTE(review): callers receiving a null Future will NPE on get(); consider
		// throwing UnsupportedOperationException or implementing this.
		return null;
	}

	/**
	 * Asynchronous evaluation is not implemented; always returns {@code null}.
	 */
	@Override
	public Future<Evaluation> evaluateAsync(DataSetIterator iterator) throws Exception {
		// NOTE(review): callers receiving a null Future will NPE on get(); consider
		// throwing UnsupportedOperationException or implementing this.
		return null;
	}

	/**
	 * Returns the restored network, failing with a descriptive error instead of a bare
	 * NullPointerException when {@link #load()} has not been called yet.
	 */
	private MultiLayerNetwork requireLoaded() {
		if (mln == null) {
			throw new IllegalStateException("model not loaded; call load() first");
		}
		return mln;
	}
}
